from fastai.vision.all import *
from collections import defaultdict
import wandb

def _make_mlp():
    """Return a fresh small network used here as a (batch, 1) -> (batch, 1) MLP.

    Built from fastai's ``create_head(1, 1, [20, 20])`` with its first three
    layers sliced off. NOTE(review): presumably the dropped layers are the
    pooling/flatten stem meant for 4-D CNN feature maps, so the remainder
    accepts 2-D input; the exact layer list depends on the installed fastai
    version — confirm before upgrading fastai.
    """
    return create_head(1, 1, [20, 20])[3:]


# Two independent copies of the same architecture: the discriminator emits one
# logit per 1-D sample, the generator maps 1-D noise to 1-D samples.
discriminator = _make_mlp()
generator = _make_mlp()

# One optimizer per network, both at lr=1e-3. NOTE(review): `Adam` comes from
# the fastai star import (presumably fastai's Adam, not torch.optim.Adam).
opt1, opt2 = (Adam(net.parameters(), 1e-3) for net in (discriminator, generator))

# Experiment tracking: per-step losses (and periodic generator histograms) are
# sent to the 'random_codes' Weights & Biases project by the training loop.
wandb.init(project='random_codes')
# Adversarial training loop. Each step:
#   1. update the discriminator to separate real samples from generator samples;
#   2. update the generator so its samples are classified as real.
# Real data is drawn from N(2, 2^2); the generator maps uniform noise in [0, 1)
# to 1-D samples.
n_steps = 1000
real_batch_size = 10
gen_batch_size = 100
hist_every = 100

for e in range(n_steps):
    # --- Discriminator step -------------------------------------------------
    x_true = torch.randn(real_batch_size, 1) * 2 + 2
    # Fakes must come from the generator (detached so this step cannot update
    # it); otherwise the discriminator never sees what the generator produces
    # and the game is not adversarial.
    x_fake = generator(torch.rand(real_batch_size, 1)).detach()
    x = torch.cat([x_true, x_fake], 0)
    label = torch.cat([torch.ones_like(x_true), torch.zeros_like(x_fake)], 0)

    # The *_with_logits form fuses sigmoid + BCE and is numerically stable.
    loss1 = F.binary_cross_entropy_with_logits(discriminator(x), label)
    opt1.zero_grad()
    loss1.backward()
    opt1.step()

    # --- Generator step -----------------------------------------------------
    noise = torch.rand(gen_batch_size, 1)
    gen = generator(noise)
    preds = discriminator(gen)
    # Non-saturating generator loss: push the discriminator's logits on
    # generated samples towards the "real" label (1). Minimizing the raw mean
    # logit instead would push samples towards being classified as fake.
    loss2 = F.binary_cross_entropy_with_logits(preds, torch.ones_like(preds))
    opt2.zero_grad()
    # This backward also leaves gradients on the discriminator's parameters;
    # they are cleared by opt1.zero_grad() before the next discriminator step.
    loss2.backward()
    opt2.step()

    log = {
        'loss1': loss1.item(),
        'loss2': loss2.item(),
    }
    if e % hist_every == 0:
        log['generator'] = wandb.Histogram(gen.detach().flatten().numpy())
    wandb.log(log)
